{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing Transformers with PyTorch\n",
    "\n",
    "In this example, we'll implement a transformer from scratch using PyTorch (we won't rely on the default implementation). We'll train the transformer over randomly generated text sequences. This simple task will allow us to focus on the transformer details, rather than a specific problem.\n",
    "\n",
    "_This example is based on_ [https://github.com/harvardnlp/annotated-transformer](https://github.com/harvardnlp/annotated-transformer) <br/>\n",
    "_Copyright (c) 2018 Alexander Rush<br/>\n",
    "Copyright (c) 2019 Ivan Vasilev<br/>\n",
    "License: MIT_\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start with the imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with the implementation of the base single-head attention mechanism:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention(query, key, value, mask=None, dropout=None):\n",
    "    \"\"\"Scaled Dot Product Attention\"\"\"\n",
    "    d_k = query.size(-1)\n",
    "\n",
    "    # 1) and 2) Compute the alignment scores with scaling\n",
    "    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "    if mask is not None:\n",
    "        scores = scores.masked_fill(mask == 0, -1e9)\n",
    "\n",
    "    # 3) Compute the attention scores (softmax)\n",
    "    p_attn = torch.nn.functional.softmax(scores, dim=-1)\n",
    "\n",
    "    if dropout is not None:\n",
    "        p_attn = dropout(p_attn)\n",
    "\n",
    "    # 4) Apply the attention scores over the values\n",
    "    return torch.matmul(p_attn, value), p_attn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll implement multi-head attention, which uses `attention` internally:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadedAttention(torch.nn.Module):\n",
    "    \"\"\"Multi-head attention, which runs `attention` over h projected heads\"\"\"\n",
    "\n",
    "    def __init__(self, h, d_model, dropout=0.1):\n",
    "        \"\"\"\n",
    "        :param h: number of heads; d_model must be divisible by it\n",
    "        :param d_model: query/key/value vector length\n",
    "        :param dropout: dropout probability for the attention weights\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        assert d_model % h == 0\n",
    "        self.h = h\n",
    "        # Vector length of a single head (d_v is assumed to equal d_k)\n",
    "        self.d_k = d_model // h\n",
    "\n",
    "        # Four d_model -> d_model fully connected layers:\n",
    "        # the first three project the query/key/value,\n",
    "        # the last one combines the concatenated head outputs\n",
    "        self.fc_layers = clones(torch.nn.Linear(d_model, d_model), 4)\n",
    "        # Attention weights produced by the most recent forward pass\n",
    "        self.attn = None\n",
    "        self.dropout = torch.nn.Dropout(p=dropout)\n",
    "\n",
    "    def forward(self, query, key, value, mask=None):\n",
    "        n_batch = query.size(0)\n",
    "\n",
    "        # Add a head dimension, so the same mask is broadcast to all h heads\n",
    "        if mask is not None:\n",
    "            mask = mask.unsqueeze(1)\n",
    "\n",
    "        # 1) Project the inputs and split d_model into h heads of length d_k.\n",
    "        #    zip stops after three layers; the fourth one is used in step 3\n",
    "        query, key, value = [\n",
    "            layer(tensor).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)\n",
    "            for layer, tensor in zip(self.fc_layers, (query, key, value))\n",
    "        ]\n",
    "\n",
    "        # 2) Run the attention over all heads at once\n",
    "        heads, self.attn = attention(query, key, value,\n",
    "                                     mask=mask,\n",
    "                                     dropout=self.dropout)\n",
    "\n",
    "        # 3) Merge the heads back into vectors of length h * d_k\n",
    "        #    and apply the final linear layer\n",
    "        merged = heads.transpose(1, 2).contiguous()\n",
    "        merged = merged.view(n_batch, -1, self.h * self.d_k)\n",
    "\n",
    "        return self.fc_layers[-1](merged)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we'll implement the `clones` helper function, which allows us to copy an existing `torch.nn.Module` `n` times:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clones(module: torch.nn.Module, n: int):\n",
    "    \"\"\"\n",
    "    Produce N identical copies of module in a ModuleList\n",
    "    :param module: The module to be copied.\n",
    "        The module itself is not part of the output module list\n",
    "     :param n: Number of copies\n",
    "    \"\"\"\n",
    "    return torch.nn.ModuleList([copy.deepcopy(module) for _ in range(n)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with several of the smaller building blocks of the transformer encoder and decoder blocks. First, we'll implement the position-wise feedforward network:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionwiseFFN(torch.nn.Module):\n",
    "    \"\"\"Implements FFN equation from the paper\"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, d_ff: int, dropout=0.1):\n",
    "        super(PositionwiseFFN, self).__init__()\n",
    "        self.w_1 = torch.nn.Linear(d_model, d_ff)\n",
    "        self.w_2 = torch.nn.Linear(d_ff, d_model)\n",
    "        self.dropout = torch.nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.w_2(self.dropout(torch.nn.functional.relu(self.w_1(x))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we'll implement the `Embeddings` class, which sits as the first layer of both the encoder and the decoder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Embeddings(torch.nn.Module):\n",
    "    \"\"\"Encoder/Decoder input embeddings\"\"\"\n",
    "\n",
    "    def __init__(self, d_model, vocab_size):\n",
    "        super(Embeddings, self).__init__()\n",
    "        self.lut = torch.nn.Embedding(vocab_size, d_model)\n",
    "        self.d_model = d_model\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.lut(x) * math.sqrt(self.d_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with the sub-layer residual connection of the transformer blocks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SublayerConnection(torch.nn.Module):\n",
    "    \"\"\"\n",
    "    Residual connection around a sub-layer, combined with layer normalization.\n",
    "    For code simplicity, the normalization is applied to the sub-layer input\n",
    "    (norm first) rather than to the sum (norm last).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, size, dropout):\n",
    "        \"\"\"\n",
    "        :param size: number of features passed to the layer normalization\n",
    "        :param dropout: dropout probability applied to the sub-layer output\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.norm = LayerNorm(size)\n",
    "        self.dropout = torch.nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x, sublayer):\n",
    "        \"\"\"Wrap any same-size sublayer (a callable) in a residual connection.\"\"\"\n",
    "        normalized = self.norm(x)\n",
    "        residual = self.dropout(sublayer(normalized))\n",
    "\n",
    "        return x + residual"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we'll implement the normalization:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LayerNorm(torch.nn.Module):\n",
    "    \"\"\"Construct a layer normalization module (See citation for details).\"\"\"\n",
    "\n",
    "    def __init__(self, features: int, eps=1e-6):\n",
    "        super(LayerNorm, self).__init__()\n",
    "        self.a_2 = torch.nn.Parameter(torch.ones(features))\n",
    "        self.b_2 = torch.nn.Parameter(torch.zeros(features))\n",
    "        self.eps = eps\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean = x.mean(-1, keepdim=True)\n",
    "        std = x.std(-1, keepdim=True)\n",
    "        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll also implement the positional encoding of the tokens:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEncoding(torch.nn.Module):\n",
    "    def __init__(self, d_model, dropout, max_len=5000):\n",
    "        super(PositionalEncoding, self).__init__()\n",
    "        self.dropout = torch.nn.Dropout(p=dropout)\n",
    "\n",
    "        # Compute the positional encodings once in log space.\n",
    "        pe = torch.zeros(max_len, d_model)\n",
    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) *\n",
    "                             -(math.log(10000.0) / d_model))\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "        pe = pe.unsqueeze(0)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + torch.autograd.Variable(self.pe[:, :x.size(1)],\n",
    "                                        requires_grad=False)\n",
    "        return self.dropout(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we have everything necessary to implement one encoder block..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderBlock(torch.nn.Module):\n",
    "    \"\"\"One encoder block: self-attention and FFN, each in a residual sub-layer\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 size: int,\n",
    "                 self_attn: MultiHeadedAttention,\n",
    "                 ffn: PositionwiseFFN,\n",
    "                 dropout=0.1):\n",
    "        \"\"\"\n",
    "        :param size: number of features of the sub-layer connections\n",
    "        :param self_attn: the self-attention module\n",
    "        :param ffn: the position-wise feed-forward network\n",
    "        :param dropout: dropout probability of the sub-layer connections\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.self_attn = self_attn\n",
    "        self.ffn = ffn\n",
    "\n",
    "        # Two sub-layer connections:\n",
    "        # the first wraps the self-attention, the second wraps the FFN\n",
    "        self.sublayers = clones(SublayerConnection(size, dropout), 2)\n",
    "        self.size = size\n",
    "\n",
    "    def forward(self, x, mask):\n",
    "        \"\"\"Run the self-attention sub-layer, then the FFN sub-layer\"\"\"\n",
    "\n",
    "        def attend(normalized):\n",
    "            # Self-attention: the same tensor serves as query, key and value\n",
    "            return self.self_attn(normalized, normalized, normalized, mask)\n",
    "\n",
    "        attention_sublayer, ffn_sublayer = self.sublayers\n",
    "        x = attention_sublayer(x, attend)\n",
    "\n",
    "        return ffn_sublayer(x, self.ffn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... and the encoder itself, which consists of stacked instances of `EncoderBlock`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(torch.nn.Module):\n",
    "    \"\"\"Transformer encoder: N stacked blocks and a final layer normalization\"\"\"\n",
    "\n",
    "    def __init__(self, block: EncoderBlock, N: int):\n",
    "        \"\"\"\n",
    "        :param block: template block, which is deep-copied N times\n",
    "        :param N: number of blocks in the stack\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.blocks = clones(block, N)\n",
    "        self.norm = LayerNorm(block.size)\n",
    "\n",
    "    def forward(self, x, mask):\n",
    "        \"\"\"Feed the input through every block in turn and normalize the result\"\"\"\n",
    "        hidden = x\n",
    "        for block in self.blocks:\n",
    "            hidden = block(hidden, mask)\n",
    "\n",
    "        return self.norm(hidden)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's focus on the decoder, starting from the decoder sub-block:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderBlock(torch.nn.Module):\n",
    "    \"\"\"One decoder block: self-attention, encoder-attention, and FFN sub-layers\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 size: int,\n",
    "                 self_attn: MultiHeadedAttention,\n",
    "                 encoder_attn: MultiHeadedAttention,\n",
    "                 ffn: PositionwiseFFN,\n",
    "                 dropout=0.1):\n",
    "        \"\"\"\n",
    "        :param size: number of features of the sub-layer connections\n",
    "        :param self_attn: attention of the decoder input over itself\n",
    "        :param encoder_attn: attention of the decoder over the encoder states\n",
    "        :param ffn: the position-wise feed-forward network\n",
    "        :param dropout: dropout probability of the sub-layer connections\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.size = size\n",
    "        self.self_attn = self_attn\n",
    "        self.encoder_attn = encoder_attn\n",
    "        self.ffn = ffn\n",
    "\n",
    "        # Three sub-layer connections, in this order:\n",
    "        # self-attention, encoder attention, FFN\n",
    "        self.sublayers = clones(SublayerConnection(size, dropout), 3)\n",
    "\n",
    "    def forward(self, x, encoder_states, source_mask, target_mask):\n",
    "        def self_attend(normalized):\n",
    "            # Query, key and value all come from the decoder input\n",
    "            return self.self_attn(normalized, normalized, normalized, target_mask)\n",
    "\n",
    "        def encoder_attend(normalized):\n",
    "            # The query comes from the decoder; keys/values are the encoder states\n",
    "            return self.encoder_attn(normalized, encoder_states, encoder_states, source_mask)\n",
    "\n",
    "        self_sublayer, encoder_sublayer, ffn_sublayer = self.sublayers\n",
    "        x = self_sublayer(x, self_attend)\n",
    "        x = encoder_sublayer(x, encoder_attend)\n",
    "\n",
    "        return ffn_sublayer(x, self.ffn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "... and continuing with the decoder itself:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(torch.nn.Module):\n",
    "    \"\"\"Transformer decoder: N masked blocks, layer norm, and a vocabulary projection\"\"\"\n",
    "\n",
    "    def __init__(self, block: DecoderBlock, N: int, vocab_size: int):\n",
    "        \"\"\"\n",
    "        :param block: template block, which is deep-copied N times\n",
    "        :param N: number of blocks in the stack\n",
    "        :param vocab_size: output size of the final projection\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.blocks = clones(block, N)\n",
    "        self.norm = LayerNorm(block.size)\n",
    "        self.projection = torch.nn.Linear(block.size, vocab_size)\n",
    "\n",
    "    def forward(self, x, encoder_states, source_mask, target_mask):\n",
    "        hidden = x\n",
    "        for block in self.blocks:\n",
    "            hidden = block(hidden, encoder_states, source_mask, target_mask)\n",
    "\n",
    "        # Project the normalized states to vocabulary scores\n",
    "        # and return them as log-probabilities\n",
    "        scores = self.projection(self.norm(hidden))\n",
    "\n",
    "        return torch.nn.functional.log_softmax(scores, dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's implement the full transformer model, which combines the encoder and the decoder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderDecoder(torch.nn.Module):\n",
    "    \"\"\"An encoder-decoder architecture\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 encoder: Encoder,\n",
    "                 decoder: Decoder,\n",
    "                 source_embeddings: torch.nn.Sequential,\n",
    "                 target_embeddings: torch.nn.Sequential):\n",
    "        \"\"\"\n",
    "        :param encoder: module that encodes the embedded source sequence\n",
    "        :param decoder: module that decodes the embedded target sequence\n",
    "        :param source_embeddings: module applied to the source tokens\n",
    "        :param target_embeddings: module applied to the target tokens\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.source_embeddings = source_embeddings\n",
    "        self.target_embeddings = target_embeddings\n",
    "\n",
    "    def forward(self, source, target, source_mask, target_mask):\n",
    "        \"\"\"Encode the masked source, then decode the masked target against it.\"\"\"\n",
    "        embedded_source = self.source_embeddings(source)\n",
    "        encoder_output = self.encoder(x=embedded_source, mask=source_mask)\n",
    "\n",
    "        embedded_target = self.target_embeddings(target)\n",
    "\n",
    "        return self.decoder(\n",
    "            x=embedded_target,\n",
    "            encoder_states=encoder_output,\n",
    "            source_mask=source_mask,\n",
    "            target_mask=target_mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's define the `build_model` function, which builds and initializes the combined model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(source_vocabulary: int,\n",
    "                target_vocabulary: int,\n",
    "                N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):\n",
    "    \"\"\"\n",
    "    Build the full transformer model.\n",
    "\n",
    "    :param source_vocabulary: size of the source vocabulary\n",
    "    :param target_vocabulary: size of the target vocabulary\n",
    "    :param N: number of encoder blocks and of decoder blocks\n",
    "    :param d_model: embedding / hidden vector length\n",
    "    :param d_ff: hidden size of the position-wise feed-forward networks\n",
    "    :param h: number of attention heads\n",
    "    :param dropout: dropout probability, shared by all sub-modules\n",
    "    :return: EncoderDecoder whose weight matrices are Xavier-initialized\n",
    "    \"\"\"\n",
    "    c = copy.deepcopy\n",
    "    # Forward the dropout rate, so the attention layers honor it too\n",
    "    # (instead of always using their own default)\n",
    "    attn = MultiHeadedAttention(h, d_model, dropout)\n",
    "    ff = PositionwiseFFN(d_model, d_ff, dropout)\n",
    "    position = PositionalEncoding(d_model, dropout)\n",
    "\n",
    "    # Every component receives its own deep copy of the template modules,\n",
    "    # so no weights are shared between them\n",
    "    model = EncoderDecoder(\n",
    "        encoder=Encoder(EncoderBlock(d_model, c(attn), c(ff), dropout), N),\n",
    "        decoder=Decoder(DecoderBlock(d_model, c(attn), c(attn),\n",
    "                                     c(ff), dropout), N, target_vocabulary),\n",
    "        source_embeddings=torch.nn.Sequential(\n",
    "            Embeddings(d_model, source_vocabulary), c(position)),\n",
    "        target_embeddings=torch.nn.Sequential(\n",
    "            Embeddings(d_model, target_vocabulary), c(position)))\n",
    "\n",
    "    # Initialize all weight matrices with Glorot / fan_avg.\n",
    "    # One-dimensional parameters (biases, norm gains) keep their defaults\n",
    "    for p in model.parameters():\n",
    "        if p.dim() > 1:\n",
    "            torch.nn.init.xavier_uniform_(p)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll continue with some boilerplate code, which will generate `total_samples` random sequences. The labels of each sequence are the same sequence shifted by one token, so the model learns a simple copy task:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RandomDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Random data copy dataset\"\"\"\n",
    "\n",
    "    def __init__(self, V, total_samples, sample_length):\n",
    "        self.samples = list()\n",
    "\n",
    "        sample = dict()\n",
    "        for i in range(total_samples):\n",
    "            data = torch.from_numpy(np.random.randint(1, V, size=(sample_length,)))\n",
    "            data[0] = 1\n",
    "            source = torch.autograd.Variable(data, requires_grad=False)\n",
    "            target = torch.autograd.Variable(data, requires_grad=False)\n",
    "\n",
    "            sample['source'] = source\n",
    "            sample['target'] = target[:-1]\n",
    "            sample['target_y'] = target[1:]\n",
    "            sample['source_mask'] = (source != 0).unsqueeze(-2)\n",
    "            sample['target_mask'] = self.make_std_mask(sample['target'], 0)\n",
    "            sample['tokens_count'] = (sample['target_y'] != 0).data.sum()\n",
    "\n",
    "            self.samples.append(sample)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.samples[idx]\n",
    "\n",
    "    @staticmethod\n",
    "    def make_std_mask(target, pad):\n",
    "        \"\"\"Create a mask to hide padding and future words.\"\"\"\n",
    "        target_mask = (target != pad)\n",
    "        target_mask = target_mask & torch.autograd.Variable(\n",
    "            RandomDataset.subsequent_mask(target.size(-1)).type_as(target_mask.data))\n",
    "\n",
    "        return target_mask\n",
    "\n",
    "    @staticmethod\n",
    "    def subsequent_mask(size):\n",
    "        \"\"\"Mask out subsequent positions.\"\"\"\n",
    "        attn_shape = (size, size)\n",
    "        subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')\n",
    "        return torch.from_numpy(subsequent_mask) == 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's continue with the transformer training procedure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, loss_function, optimizer, data_loader):\n",
    "    # set model to training mode\n",
    "    model.train()\n",
    "\n",
    "    current_loss = 0.0\n",
    "    counter = 0\n",
    "\n",
    "    # iterate over the training data\n",
    "    for i, batch in enumerate(data_loader):\n",
    "        with torch.set_grad_enabled(True):\n",
    "            out = model.forward(batch['source'], batch['target'],\n",
    "                                batch['source_mask'], batch['target_mask'])\n",
    "\n",
    "            loss = loss_function(out.contiguous().view(-1, out.size(-1)),\n",
    "                                 batch['target_y'].contiguous().view(-1))\n",
    "\n",
    "            loss.backward()\n",
    "\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # statistics\n",
    "            current_loss += loss\n",
    "            counter += 1\n",
    "\n",
    "            if counter % 5 == 0:\n",
    "                print(\"Batch: %d; Loss: %f\" % (i + 1, current_loss / counter))\n",
    "                current_loss = 0.0\n",
    "                counter = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can put it all together. We'll instantiate the model, generate a random dataset, and start the training. Since the task itself (copying random sequences) is not important, we're only interested in seeing the loss function decrease:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch: 5; Loss: 2.789809\n",
      "Batch: 10; Loss: 0.548198\n",
      "Batch: 15; Loss: 0.240894\n",
      "Batch: 20; Loss: 0.161667\n",
      "Batch: 25; Loss: 0.111362\n",
      "Batch: 30; Loss: 0.067419\n",
      "Batch: 35; Loss: 0.027662\n",
      "Batch: 40; Loss: 0.009392\n",
      "Batch: 45; Loss: 0.003070\n",
      "Batch: 50; Loss: 0.002872\n",
      "Batch: 55; Loss: 0.004820\n",
      "Batch: 60; Loss: 0.001508\n",
      "Batch: 65; Loss: 0.001070\n",
      "Batch: 70; Loss: 0.000622\n",
      "Batch: 75; Loss: 0.001101\n",
      "Batch: 80; Loss: 0.000606\n",
      "Batch: 85; Loss: 0.000218\n",
      "Batch: 90; Loss: 0.000161\n",
      "Batch: 95; Loss: 0.000260\n",
      "Batch: 100; Loss: 0.000342\n",
      "Batch: 105; Loss: 0.000628\n",
      "Batch: 110; Loss: 0.000309\n",
      "Batch: 115; Loss: 0.000084\n",
      "Batch: 120; Loss: 0.000161\n",
      "Batch: 125; Loss: 0.000227\n",
      "Batch: 130; Loss: 0.000163\n",
      "Batch: 135; Loss: 0.000390\n",
      "Batch: 140; Loss: 0.000138\n",
      "Batch: 145; Loss: 0.000146\n",
      "Batch: 150; Loss: 0.000050\n",
      "Batch: 155; Loss: 0.000324\n",
      "Batch: 160; Loss: 0.000088\n",
      "Batch: 165; Loss: 0.000059\n",
      "Batch: 170; Loss: 0.000122\n",
      "Batch: 175; Loss: 0.000037\n",
      "Batch: 180; Loss: 0.000381\n",
      "Batch: 185; Loss: 0.000128\n",
      "Batch: 190; Loss: 0.000031\n",
      "Batch: 195; Loss: 0.000244\n",
      "Batch: 200; Loss: 0.000015\n"
     ]
    }
   ],
   "source": [
    "V = 11  # vocabulary size (token 0 is reserved for padding)\n",
    "BATCH_SIZE = 50\n",
    "TOTAL_BATCHES = 200\n",
    "SAMPLE_LENGTH = 10\n",
    "SEED = 0\n",
    "\n",
    "# Seed both random generators, so the dataset, the weight initialization,\n",
    "# and the dropout are reproducible across runs\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Use the same vocabulary size for the dataset and for the model\n",
    "train_set = RandomDataset(V, BATCH_SIZE * TOTAL_BATCHES, SAMPLE_LENGTH)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set,\n",
    "                                           batch_size=BATCH_SIZE)\n",
    "\n",
    "model = build_model(V, V, N=2)\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "# The decoder already outputs log-probabilities; the log-softmax inside\n",
    "# CrossEntropyLoss leaves log-probabilities unchanged\n",
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "train_model(model, loss_function, optimizer, train_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
